/* -*- c -*- */
#define PY_SSIZE_T_CLEAN
#include <Python.h>

#define _UMATHMODULE
#define _MULTIARRAYMODULE
#define NPY_NO_DEPRECATED_API NPY_API_VERSION

#include "npy_config.h"
#include "numpy/npy_common.h"
#include "numpy/arrayobject.h"
#include "numpy/ufuncobject.h"
#include "numpy/npy_math.h"
#include "numpy/halffloat.h"
#include "lowlevel_strided_loops.h"



#include "blas_utils.h"
#include "npy_cblas.h"
#include "arraytypes.h" /* For TYPE_dot functions */

#include <assert.h>

/*
 *****************************************************************************
 **                            BASICS                                       **
 *****************************************************************************
 */

/*
 * Absolute value of x.  This is a function-like macro: the argument is
 * evaluated more than once, so do not pass expressions with side effects.
 */
#define ABS(x) ((x) < 0 ? -(x) : (x))

#if defined(HAVE_CBLAS)
/*
 * Upper limit applied in this file to every dimension and leading stride
 * (counted in items) before it is narrowed to the BLAS integer type.  The
 * INT64-based limit is selected when HAVE_BLAS_ILP64 is defined, the
 * int-based limit otherwise.
 *
 * -1 to be conservative, in case blas internally uses a for loop with an
 * inclusive upper bound
 */
#ifndef HAVE_BLAS_ILP64
#define BLAS_MAXSIZE (NPY_MAX_INT - 1)
#else
#define BLAS_MAXSIZE (NPY_MAX_INT64 - 1)
#endif

/*
 * Decide whether a 2d matrix described by byte strides can be passed to
 * BLAS as-is.  All of the following must hold:
 * 1. The faster (second) axis is contiguous: its stride is exactly one item.
 * 2. The slower (first) axis stride is a whole number of items.
 * 3. That stride, counted in items, is at least d2 (so consecutive rows do
 *    not alias or overlap) and no larger than BLAS_MAXSIZE.
 * d1 is part of the signature but is not inspected.
 */
static inline npy_bool
is_blasable2d(npy_intp byte_stride1, npy_intp byte_stride2,
              npy_intp d1, npy_intp d2,  npy_intp itemsize)
{
    const npy_intp leading = byte_stride1 / itemsize;
    const int whole_items = (byte_stride1 % itemsize) == 0;
    const int in_range = (leading >= d2) && (leading <= BLAS_MAXSIZE);

    if (byte_stride2 == itemsize && whole_items && in_range) {
        return NPY_TRUE;
    }
    return NPY_FALSE;
}

/*
 * Complex one and zero constants.  Their addresses are used as the two
 * scalar arguments of the complex BLAS calls in this file.  MSVC (when it
 * is not the Intel compiler) takes the brace-initialized form; every other
 * compiler takes the plain scalar initializers.
 */
#if defined(_MSC_VER) && !defined(__INTEL_COMPILER)
static const npy_cdouble oneD = {1.0, 0.0}, zeroD = {0.0, 0.0};
static const npy_cfloat oneF = {1.0f, 0.0f}, zeroF = {0.0f, 0.0f};
#else
static const npy_cdouble oneD = 1.0, zeroD = 0.0;
static const npy_cfloat oneF = 1.0f, zeroF = 0.0f;
#endif

/**begin repeat
 *
 * #name = FLOAT, DOUBLE, CFLOAT, CDOUBLE#
 * #ctype = npy_float, npy_double, npy_cfloat, npy_cdouble#
 * #typ = npy_float, npy_double, npy_cfloat, npy_cdouble#
 * #prefix = s, d, c, z#
 * #step1 = 1.F, 1., &oneF, &oneD#
 * #step0 = 0.F, 0., &zeroF, &zeroD#
 */

/*
 * Copy a dm x dn matrix of @ctype@ items between two strided layouts.
 *
 * Element (m, n) is read from _ip + m*is_m + n*is_n and written to
 * _op + m*os_m + n*os_n; all strides are in bytes.  `transpose` does not
 * change which elements are copied, only the traversal order: when it is
 * set the m index runs in the inner loop, otherwise the n index does.
 */
static inline void
@name@_matrix_copy(npy_bool transpose,
                   void *_ip, npy_intp is_m, npy_intp is_n,
                   void *_op, npy_intp os_m, npy_intp os_n,
                   npy_intp dm, npy_intp dn)
{
    char *src_line = (char *)_ip;
    char *dst_line = (char *)_op;

    if (!transpose) {
        /* outer loop over m, inner loop over n */
        for (npy_intp i = 0; i < dm; i++, src_line += is_m, dst_line += os_m) {
            char *src = src_line;
            char *dst = dst_line;
            for (npy_intp j = 0; j < dn; j++, src += is_n, dst += os_n) {
                *(@ctype@ *)dst = *(@ctype@ *)src;
            }
        }
        return;
    }

    /* outer loop over n, inner loop over m */
    for (npy_intp j = 0; j < dn; j++, src_line += is_n, dst_line += os_n) {
        char *src = src_line;
        char *dst = dst_line;
        for (npy_intp i = 0; i < dm; i++, src += is_m, dst += os_m) {
            *(@ctype@ *)dst = *(@ctype@ *)src;
        }
    }
}

/*
 * Matrix-vector product through Level 2 BLAS (cblas_@prefix@gemv).
 *
 * arguments
 * ip1: m*n matrix; one of its two axes must be contiguous
 * ip2: vector of n items, stride is2_n bytes
 * op:  vector of m items, stride op_m bytes
 *
 * All strides are in bytes.  The memory layout of ip1 selects the storage
 * order passed to BLAS; the matrix is always passed as an N x M operand
 * with CblasTrans.  Callers are expected to have verified that the sizes
 * and strides are usable by BLAS: the checks below are asserts only and
 * vanish in release builds.
 */
static void
@name@_gemv(void *ip1, npy_intp is1_m, npy_intp is1_n,
            void *ip2, npy_intp is2_n,
            void *op, npy_intp op_m,
            npy_intp m, npy_intp n)
{
    /* Start from the "n axis is the leading axis" layout ... */
    enum CBLAS_ORDER order = CblasRowMajor;
    npy_intp leading_stride = is1_n;

    assert(m <= BLAS_MAXSIZE && n <= BLAS_MAXSIZE);
    assert (is_blasable2d(is2_n, sizeof(@typ@), n, 1, sizeof(@typ@)));

    if (is_blasable2d(is1_m, is1_n, m, n, sizeof(@typ@))) {
        /* ... and switch when the m axis is the leading one. */
        order = CblasColMajor;
        leading_stride = is1_m;
    }
    else {
        /* If not ColMajor, caller should have ensured we are RowMajor */
        /* will not assert in release mode */
        assert(is_blasable2d(is1_n, is1_m, n, m, sizeof(@typ@)));
    }

    const CBLAS_INT M = (CBLAS_INT)m;
    const CBLAS_INT N = (CBLAS_INT)n;
    const CBLAS_INT lda = (CBLAS_INT)(leading_stride / sizeof(@typ@));

    CBLAS_FUNC(cblas_@prefix@gemv)(order, CblasTrans, N, M, @step1@, ip1, lda, ip2,
                                     is2_n / sizeof(@typ@), @step0@, op, op_m / sizeof(@typ@));
}

/*
 * matrix matrix multiplication -- Level 3 BLAS
 *
 * Computes op (m x p) = ip1 (m x n) times ip2 (n x p) with a row-major
 * BLAS call.  All strides are in bytes.  The output must have a contiguous
 * p axis; each input must have one contiguous axis, and the one that is
 * contiguous decides whether the operand is passed transposed.  These
 * layout requirements are only asserted, so the caller must guarantee them.
 *
 * When both inputs are the same memory viewed as a matrix and its
 * transpose, syrk fills the upper triangle and the lower triangle is
 * mirrored from it; every other case goes through gemm.
 */
static void
@name@_matmul_matrixmatrix(void *ip1, npy_intp is1_m, npy_intp is1_n,
                           void *ip2, npy_intp is2_n, npy_intp is2_p,
                           void *op, npy_intp os_m, npy_intp os_p,
                           npy_intp m, npy_intp n, npy_intp p)
{
    const enum CBLAS_ORDER order = CblasRowMajor;
    enum CBLAS_TRANSPOSE trans1 = CblasNoTrans, trans2 = CblasNoTrans;
    npy_intp leading1 = is1_m, leading2 = is2_n;

    assert(m <= BLAS_MAXSIZE && n <= BLAS_MAXSIZE && p <= BLAS_MAXSIZE);
    assert(is_blasable2d(os_m, os_p, m, p, sizeof(@typ@)));

    const CBLAS_INT M = (CBLAS_INT)m;
    const CBLAS_INT N = (CBLAS_INT)n;
    const CBLAS_INT P = (CBLAS_INT)p;
    const CBLAS_INT ldc = (CBLAS_INT)(os_m / sizeof(@typ@));

    if (!is_blasable2d(is1_m, is1_n, m, n, sizeof(@typ@))) {
        /* If not ColMajor, caller should have ensured we are RowMajor */
        /* will not assert in release mode */
        assert(is_blasable2d(is1_n, is1_m, n, m, sizeof(@typ@)));
        trans1 = CblasTrans;
        leading1 = is1_n;
    }
    if (!is_blasable2d(is2_n, is2_p, n, p, sizeof(@typ@))) {
        /* If not ColMajor, caller should have ensured we are RowMajor */
        /* will not assert in release mode */
        assert(is_blasable2d(is2_p, is2_n, p, n, sizeof(@typ@)));
        trans2 = CblasTrans;
        leading2 = is2_p;
    }

    const CBLAS_INT lda = (CBLAS_INT)(leading1 / sizeof(@typ@));
    const CBLAS_INT ldb = (CBLAS_INT)(leading2 / sizeof(@typ@));

    /*
     * Use syrk if we have a case of a matrix times its transpose.
     * Otherwise, use gemm for all other cases.
     */
    const int matrix_times_own_transpose =
        (ip1 == ip2) &&
        (m == p) &&
        (is1_m == is2_p) &&
        (is1_n == is2_n) &&
        (trans1 != trans2);

    if (!matrix_times_own_transpose) {
        CBLAS_FUNC(cblas_@prefix@gemm)(
            order, trans1, trans2, M, P, N, @step1@, ip1, lda,
            ip2, ldb, @step0@, op, ldc);
        return;
    }

    /* The non-transposed operand supplies the leading dimension. */
    CBLAS_FUNC(cblas_@prefix@syrk)(
        order, CblasUpper, trans1, P, N, @step1@,
        ip1, (trans1 == CblasNoTrans) ? lda : ldb, @step0@, op, ldc);

    /* Copy the triangle */
    @typ@ *out = (@typ@ *)op;
    for (npy_intp i = 0; i < P; i++) {
        for (npy_intp j = i + 1; j < P; j++) {
            out[j * ldc + i] = out[i * ldc + j];
        }
    }
}

/**end repeat**/
#endif

/*
 * matmul loops
 * signature is (m?,n),(n,p?)->(m?,p?)
 */

/**begin repeat
 *  #TYPE = LONGDOUBLE,
 *          FLOAT, DOUBLE, HALF,
 *          CFLOAT, CDOUBLE, CLONGDOUBLE,
 *          UBYTE, USHORT, UINT, ULONG, ULONGLONG,
 *          BYTE, SHORT, INT, LONG, LONGLONG#
 *  #typ = npy_longdouble,
 *         npy_float,npy_double,npy_half,
 *         npy_cfloat, npy_cdouble, npy_clongdouble,
 *         npy_ubyte, npy_ushort, npy_uint, npy_ulong, npy_ulonglong,
 *         npy_byte, npy_short, npy_int, npy_long, npy_longlong#
 * #suff = , , , , f, , l, x*10#
 * #IS_COMPLEX = 0, 0, 0, 0, 1, 1, 1, 0*10#
 * #IS_HALF = 0, 0, 0, 1, 0*13#
 */

/*
 * Plain triple-loop matrix product for @typ@ items, no BLAS involved.
 *
 * op[m, p] = sum over n of ip1[m, n] * ip2[n, p], for a dm x dn first
 * operand and a dn x dp second operand.  All strides are in bytes.  Half
 * precision accumulates in a float and converts once per output element;
 * every other type zeroes the output element and accumulates directly
 * into it.
 */
static void
@TYPE@_matmul_inner_noblas(void *_ip1, npy_intp is1_m, npy_intp is1_n,
                           void *_ip2, npy_intp is2_n, npy_intp is2_p,
                           void *_op, npy_intp os_m, npy_intp os_p,
                           npy_intp dm, npy_intp dn, npy_intp dp)

{
    char *row1 = (char *)_ip1;      /* start of the current row of ip1 */
    char *out_row = (char *)_op;    /* start of the current row of op */

    for (npy_intp m = 0; m < dm; m++, row1 += is1_m, out_row += os_m) {
        char *col2 = (char *)_ip2;  /* start of the current column of ip2 */
        char *out = out_row;

        for (npy_intp p = 0; p < dp; p++, col2 += is2_p, out += os_p) {
            char *a = row1;
            char *b = col2;
#if @IS_COMPLEX@ == 1
            npy_csetreal@suff@((@typ@ *)out, 0.0);
            npy_csetimag@suff@((@typ@ *)out, 0.0);
#elif @IS_HALF@
            float sum = 0;
#else
            *(@typ@ *)out = 0;
#endif
            for (npy_intp n = 0; n < dn; n++, a += is1_n, b += is2_n) {
                @typ@ val1 = (*(@typ@ *)a);
                @typ@ val2 = (*(@typ@ *)b);
#if @IS_HALF@
                sum += npy_half_to_float(val1) * npy_half_to_float(val2);
#elif @IS_COMPLEX@ == 1
                /* real part: re1*re2 - im1*im2 */
                npy_csetreal@suff@((@typ@ *) out,
                        npy_creal@suff@(*((@typ@ *) out)) +
                        ((npy_creal@suff@(val1) * npy_creal@suff@(val2)) -
                         (npy_cimag@suff@(val1) * npy_cimag@suff@(val2))));
                /* imaginary part: re1*im2 + im1*re2 */
                npy_csetimag@suff@((@typ@ *) out,
                        npy_cimag@suff@(*((@typ@ *) out)) +
                        ((npy_creal@suff@(val1) * npy_cimag@suff@(val2)) +
                         (npy_cimag@suff@(val1) * npy_creal@suff@(val2))));
#else
                *(@typ@ *)out += val1 * val2;
#endif
            }
#if @IS_HALF@
            *(@typ@ *)out = npy_float_to_half(sum);
#endif
        }
    }
}

/**end repeat**/
/*
 * Boolean matrix product: op[m, p] is NPY_TRUE when some n has both
 * ip1[m, n] and ip2[n, p] nonzero, NPY_FALSE otherwise.  The scan over n
 * stops at the first such pair.  All strides are in bytes.
 */
static void
BOOL_matmul_inner_noblas(void *_ip1, npy_intp is1_m, npy_intp is1_n,
                           void *_ip2, npy_intp is2_n, npy_intp is2_p,
                           void *_op, npy_intp os_m, npy_intp os_p,
                           npy_intp dm, npy_intp dn, npy_intp dp)

{
    char *row1 = (char *)_ip1;
    char *out_row = (char *)_op;

    for (npy_intp m = 0; m < dm; m++, row1 += is1_m, out_row += os_m) {
        char *col2 = (char *)_ip2;
        char *out = out_row;

        for (npy_intp p = 0; p < dp; p++, col2 += is2_p, out += os_p) {
            char *a = row1;
            char *b = col2;

            /* Assume no match, then flip on the first hit. */
            *(npy_bool *)out = NPY_FALSE;
            for (npy_intp n = 0; n < dn; n++, a += is1_n, b += is2_n) {
                if (*(npy_bool *)a != 0 && *(npy_bool *)b != 0) {
                    *(npy_bool *)out = NPY_TRUE;
                    break;
                }
            }
        }
    }
}

/*
 * Matrix product for Python object items.
 *
 * op[m, p] = sum over n of ip1[m, n] * ip2[n, p], evaluated with
 * PyNumber_Multiply and PyNumber_Add.  NULL input slots are treated as
 * None.  When dn is 0 every output element is the integer 0.  All strides
 * are in bytes.
 *
 * On a Python error the function returns immediately with the exception
 * set; output elements already written keep their new values.
 *
 * Reference handling: the running sum is owned by this function until it
 * is stored; storing it transfers ownership to the output slot and releases
 * whatever object the slot held before (the same convention OBJECT_dotc
 * uses for its output).
 */
static void
OBJECT_matmul_inner_noblas(void *_ip1, npy_intp is1_m, npy_intp is1_n,
                           void *_ip2, npy_intp is2_n, npy_intp is2_p,
                           void *_op, npy_intp os_m, npy_intp os_p,
                           npy_intp dm, npy_intp dn, npy_intp dp)
{
    char *ip1 = (char *)_ip1, *ip2 = (char *)_ip2, *op = (char *)_op;

    npy_intp ib1_n = is1_n * dn;
    npy_intp ib2_n = is2_n * dn;
    npy_intp ib2_p = is2_p * dp;
    npy_intp ob_p  = os_p * dp;

    for (npy_intp m = 0; m < dm; m++) {
        for (npy_intp p = 0; p < dp; p++) {
            /*
             * Start every output element from NULL so that the error path
             * below can never release a reference that was already handed
             * over to a previous output slot.
             */
            PyObject *sum_of_products = NULL;

            if ( 0 == dn ) {
                sum_of_products = PyLong_FromLong(0);
                if (sum_of_products == NULL) {
                    return;
                }
            }

            for (npy_intp n = 0; n < dn; n++) {
                PyObject *obj1 = *(PyObject**)ip1, *obj2 = *(PyObject**)ip2;
                if (obj1 == NULL) {
                    obj1 = Py_None;
                }
                if (obj2 == NULL) {
                    obj2 = Py_None;
                }

                PyObject *product = PyNumber_Multiply(obj1, obj2);
                if (product == NULL) {
                    /* drop the partial sum of this element only */
                    Py_XDECREF(sum_of_products);
                    return;
                }

                if (n == 0) {
                    sum_of_products = product;
                }
                else {
                    Py_SETREF(sum_of_products, PyNumber_Add(sum_of_products, product));
                    Py_DECREF(product);
                    if (sum_of_products == NULL) {
                        return;
                    }
                }

                ip2 += is2_n;
                ip1 += is1_n;
            }

            /* store the result and release the object it replaces */
            Py_XSETREF(*((PyObject **)op), sum_of_products);
            ip1 -= ib1_n;
            ip2 -= ib2_n;
            op  +=  os_p;
            ip2 += is2_p;
        }
        op -= ob_p;
        ip2 -= ib2_p;
        ip1 += is1_m;
        op  +=  os_m;
    }
}


/**begin repeat
 *  #TYPE = FLOAT, DOUBLE, LONGDOUBLE, HALF,
 *          CFLOAT, CDOUBLE, CLONGDOUBLE,
 *          UBYTE, USHORT, UINT, ULONG, ULONGLONG,
 *          BYTE, SHORT, INT, LONG, LONGLONG,
 *          BOOL, OBJECT#
 *  #typ = npy_float,npy_double,npy_longdouble, npy_half,
 *         npy_cfloat, npy_cdouble, npy_clongdouble,
 *         npy_ubyte, npy_ushort, npy_uint, npy_ulong, npy_ulonglong,
 *         npy_byte, npy_short, npy_int, npy_long, npy_longlong,
 *         npy_bool,npy_object#
 * #IS_COMPLEX = 0, 0, 0, 0, 1, 1, 1, 0*12#
 * #USEBLAS = 1, 1, 0, 0, 1, 1, 0*13#
 */


/*
 * Inner loop of matmul for @typ@ items.
 *
 * dimensions[0] is the outer loop count, followed by the core sizes
 * dm, dn, dp.  steps holds the three outer byte strides followed by the six
 * core byte strides (is1_m, is1_n, is2_n, is2_p, os_m, os_p).  args[0..2]
 * point at the two inputs and the output and are advanced in place by one
 * outer stride per iteration.
 *
 * When BLAS is enabled for this type and CBLAS is available, each core
 * product is dispatched as follows:
 *   - a zero or too large core dimension: the noblas loop;
 *   - a core dimension equal to 1: dot, gemv, or the noblas loop.  gemv is
 *     used only when the matrix, the input vector AND the output vector all
 *     have strides BLAS can express (positive whole number of items within
 *     BLAS_MAXSIZE); anything else goes to the noblas loop;
 *   - otherwise gemm/syrk, with every operand that BLAS cannot address
 *     directly staged through one scratch buffer allocated up front.
 * Without BLAS every core product uses the noblas loop.
 *
 * If the scratch buffer cannot be allocated, MemoryError is set (acquiring
 * the GIL to do so) and the function returns without computing anything.
 */
NPY_NO_EXPORT void
@TYPE@_matmul(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func))
{
    npy_intp dOuter = *dimensions++;
    npy_intp iOuter;
    npy_intp s0 = *steps++;
    npy_intp s1 = *steps++;
    npy_intp s2 = *steps++;
    npy_intp dm = dimensions[0];
    npy_intp dn = dimensions[1];
    npy_intp dp = dimensions[2];
    npy_intp is1_m=steps[0], is1_n=steps[1], is2_n=steps[2], is2_p=steps[3],
         os_m=steps[4], os_p=steps[5];
#if @USEBLAS@ && defined(HAVE_CBLAS)
    npy_intp sz = sizeof(@typ@);
    npy_bool special_case = (dm == 1 || dn == 1 || dp == 1);
    npy_bool any_zero_dim = (dm == 0 || dn == 0 || dp == 0);
    npy_bool scalar_out = (dm == 1 && dp == 1);
    npy_bool scalar_vec = (dn == 1 && (dp == 1 || dm == 1));
    npy_bool too_big_for_blas = (dm > BLAS_MAXSIZE || dn > BLAS_MAXSIZE ||
                                 dp > BLAS_MAXSIZE);
    npy_bool i1_c_blasable = is_blasable2d(is1_m, is1_n, dm, dn, sz);
    npy_bool i2_c_blasable = is_blasable2d(is2_n, is2_p, dn, dp, sz);
    npy_bool i1_f_blasable = is_blasable2d(is1_n, is1_m, dn, dm, sz);
    npy_bool i2_f_blasable = is_blasable2d(is2_p, is2_n, dp, dn, sz);
    npy_bool i1blasable = i1_c_blasable || i1_f_blasable;
    npy_bool i2blasable = i2_c_blasable || i2_f_blasable;
    npy_bool o_c_blasable = is_blasable2d(os_m, os_p, dm, dp, sz);
    npy_bool o_f_blasable = is_blasable2d(os_p, os_m, dp, dm, sz);
    npy_bool oblasable = o_c_blasable || o_f_blasable;
    /*
     * gemv receives the output stride as a BLAS increment, so the output
     * vector needs the same stride check as the input vector.  When it
     * fails, the product falls through to the noblas loop below.
     */
    npy_bool vector_matrix = ((dm == 1) && i2blasable &&
                              is_blasable2d(is1_n, sz, dn, 1, sz) &&
                              is_blasable2d(os_p, sz, dp, 1, sz));
    npy_bool matrix_vector = ((dp == 1)  && i1blasable &&
                              is_blasable2d(is2_n, sz, dn, 1, sz) &&
                              is_blasable2d(os_m, sz, dm, 1, sz));
    npy_bool noblas_fallback = too_big_for_blas || any_zero_dim;
    npy_bool matrix_matrix = !noblas_fallback && !special_case;
    npy_bool allocate_buffer = matrix_matrix && (
        !i1blasable || !i2blasable || !oblasable
    );

    /* One allocation holds the staging copies of ip1, ip2 and op, in order. */
    uint8_t *tmp_ip12op = NULL;
    void *tmp_ip1 = NULL, *tmp_ip2 = NULL, *tmp_op = NULL;

    if (allocate_buffer){
        npy_intp ip1_size = i1blasable ? 0 : sz * dm * dn,
                 ip2_size = i2blasable ? 0 : sz * dn * dp,
                 op_size = oblasable ? 0 : sz * dm * dp,
                 total_size = ip1_size + ip2_size + op_size;

        tmp_ip12op = (uint8_t*)malloc(total_size);

        if (tmp_ip12op == NULL) {
            PyGILState_STATE gil_state = PyGILState_Ensure();
            PyErr_SetString(
                PyExc_MemoryError, "Out of memory in matmul"
            );
            PyGILState_Release(gil_state);

            return;
        }

        tmp_ip1 = tmp_ip12op;
        tmp_ip2 = tmp_ip12op + ip1_size;
        tmp_op = tmp_ip12op + ip1_size + ip2_size;
    }

#endif

    for (iOuter = 0; iOuter < dOuter; iOuter++,
                         args[0] += s0, args[1] += s1, args[2] += s2) {
        void *ip1=args[0], *ip2=args[1], *op=args[2];
#if @USEBLAS@ && defined(HAVE_CBLAS)
        /*
         * TODO: refactor this out to a inner_loop_selector, in
         * PyUFunc_MatmulLoopSelector. But that call does not have access to
         * n, m, p and strides.
         */
        if (noblas_fallback) {
            @TYPE@_matmul_inner_noblas(ip1, is1_m, is1_n,
                                       ip2, is2_n, is2_p,
                                       op, os_m, os_p, dm, dn, dp);
        }
        else if (special_case) {
            /* Special case variants that have a 1 in the core dimensions */
            if (scalar_out) {
                /* row @ column, 1,1 output */
                @TYPE@_dot(ip1, is1_n, ip2, is2_n, op, dn, NULL);
            } else if (scalar_vec){
                /*
                 * 1,1d @ vector or vector @ 1,1d
                 * could use cblas_Xaxy, but that requires 0ing output
                 * and would not be faster (XXX prove it)
                 */
                @TYPE@_matmul_inner_noblas(ip1, is1_m, is1_n,
                                           ip2, is2_n, is2_p,
                                           op, os_m, os_p, dm, dn, dp);
            } else if (vector_matrix) {
                /* vector @ matrix, switch ip1, ip2, p and m */
                @TYPE@_gemv(ip2, is2_p, is2_n, ip1, is1_n,
                            op, os_p, dp, dn);
            } else if  (matrix_vector) {
                /* matrix @ vector */
                @TYPE@_gemv(ip1, is1_m, is1_n, ip2, is2_n,
                            op, os_m, dm, dn);
            } else {
                /*
                 * column @ row, 2d output, no blas needed, or an input or
                 * output that blas cannot address
                 */
                @TYPE@_matmul_inner_noblas(ip1, is1_m, is1_n,
                                           ip2, is2_n, is2_p,
                                           op, os_m, os_p, dm, dn, dp);
            }
        } else {
            /* matrix @ matrix
             * copy if not blasable, see gh-12365 & gh-23588 */
            npy_bool i1_transpose = ABS(is1_m) < ABS(is1_n),
                     i2_transpose = ABS(is2_n) < ABS(is2_p),
                     o_transpose = ABS(os_m) < ABS(os_p);

            /*
             * Strides of the staging copies: contiguous, with the axis
             * that has the smaller stride in the original kept as the
             * fast axis.
             */
            npy_intp tmp_is1_m = i1_transpose ? sz : sz*dn,
                     tmp_is1_n = i1_transpose ? sz*dm : sz,
                     tmp_is2_n = i2_transpose ? sz : sz*dp,
                     tmp_is2_p = i2_transpose ? sz*dn : sz,
                     tmp_os_m = o_transpose ? sz : sz*dp,
                     tmp_os_p = o_transpose ? sz*dm : sz;

            if (!i1blasable) {
                @TYPE@_matrix_copy(
                    i1_transpose, ip1, is1_m, is1_n,
                    tmp_ip1, tmp_is1_m, tmp_is1_n,
                    dm, dn
                );
            }

            if (!i2blasable) {
                @TYPE@_matrix_copy(
                    i2_transpose, ip2, is2_n, is2_p,
                    tmp_ip2, tmp_is2_n, tmp_is2_p,
                    dn, dp
                );
            }

            /* Operands actually handed to BLAS: originals or staging copies. */
            void *ip1_ = i1blasable ? ip1 : tmp_ip1,
                 *ip2_ = i2blasable ? ip2 : tmp_ip2,
                 *op_ = oblasable ? op : tmp_op;

            npy_intp is1_m_ = i1blasable ? is1_m : tmp_is1_m,
                     is1_n_ = i1blasable ? is1_n : tmp_is1_n,
                     is2_n_ = i2blasable ? is2_n : tmp_is2_n,
                     is2_p_ = i2blasable ? is2_p : tmp_is2_p,
                     os_m_ = oblasable ? os_m : tmp_os_m,
                     os_p_ = oblasable ? os_p : tmp_os_p;

            /*
             * Use transpose equivalence:
             * matmul(a, b, o) == matmul(b.T, a.T, o.T)
             */
            if (o_transpose) {
                @TYPE@_matmul_matrixmatrix(
                    ip2_, is2_p_, is2_n_,
                    ip1_, is1_n_, is1_m_,
                    op_, os_p_, os_m_,
                    dp, dn, dm
                );
            }
            else {
                @TYPE@_matmul_matrixmatrix(
                    ip1_, is1_m_, is1_n_,
                    ip2_, is2_n_, is2_p_,
                    op_, os_m_, os_p_,
                    dm, dn, dp
                );
            }

            /* Scatter the staged result back into the real output. */
            if(!oblasable){
                @TYPE@_matrix_copy(
                    o_transpose, tmp_op, tmp_os_m, tmp_os_p,
                    op, os_m, os_p,
                    dm, dp
                );
            }
        }
#else
        @TYPE@_matmul_inner_noblas(ip1, is1_m, is1_n,
                                   ip2, is2_n, is2_p,
                                   op, os_m, os_p, dm, dn, dp);

#endif
    }
#if @USEBLAS@ && defined(HAVE_CBLAS)
#if NPY_BLAS_CHECK_FPE_SUPPORT
    if (!npy_blas_supports_fpe()) {
        npy_clear_floatstatus_barrier((char*)args);
    }
#endif
    /* NULL when no staging buffer was needed; free(NULL) is a no-op */
    free(tmp_ip12op);
#endif
}

/**end repeat**/


/*
 * vecdot loops
 * signature is (n),(n)->()
 * We need to redefine complex and object dot from arraytypes.c.src to ensure
 * we take the complex conjugate of the first argument (for cblas, this means
 * using the dotc functions instead of dotu).
 */

/**begin repeat
 *
 * #name = CFLOAT, CDOUBLE, CLONGDOUBLE#
 * #ctype = npy_cfloat, npy_cdouble, npy_clongdouble#
 * #type = npy_float, npy_double, npy_longdouble#
 * #sumtype = npy_double, npy_double, npy_longdouble#
 * #prefix = c, z, 0#
 * #USE_BLAS = 1, 1, 0#
 */
/*
 * Conjugating dot product of two complex vectors:
 * op = sum over i of conj(ip1[i]) * ip2[i], for n items with byte strides
 * is1 and is2.
 *
 * When BLAS is enabled for this type and blas_stride() returns a nonzero
 * increment for both inputs, the sum is built from cblas dotc_sub calls on
 * chunks of at most NPY_CBLAS_CHUNK items; otherwise a plain loop is used.
 * Both paths keep the running sum in the wider sum type.
 */
static void
@name@_dotc(char *ip1, npy_intp is1, char *ip2, npy_intp is2,
            char *op, npy_intp n, void *NPY_UNUSED(ignore))
{
    @type@ *out = (@type@ *)op;

#if defined(HAVE_CBLAS) && @USE_BLAS@
    const CBLAS_INT inc1 = blas_stride(is1, sizeof(@ctype@));
    const CBLAS_INT inc2 = blas_stride(is2, sizeof(@ctype@));

    if (inc1 && inc2) {
        /* at least double for stability */
        @sumtype@ acc_re = 0., acc_im = 0.;
        npy_intp remaining = n;

        while (remaining > 0) {
            const CBLAS_INT chunk =
                remaining < NPY_CBLAS_CHUNK ? remaining : NPY_CBLAS_CHUNK;
            @type@ partial[2];

            CBLAS_FUNC(cblas_@prefix@dotc_sub)(
                    (CBLAS_INT)chunk, ip1, inc1, ip2, inc2, partial);
            acc_re += (@sumtype@)partial[0];
            acc_im += (@sumtype@)partial[1];
            /* use char strides here */
            ip1 += chunk * is1;
            ip2 += chunk * is2;
            remaining -= chunk;
        }
        out[0] = (@type@)acc_re;
        out[1] = (@type@)acc_im;
        return;
    }
#endif

    /* at least double for stability */
    @sumtype@ acc_re = 0., acc_im = 0.;
    for (npy_intp i = 0; i < n; i++, ip1 += is1, ip2 += is2) {
        const @type@ re1 = ((@type@ *)ip1)[0];
        const @type@ im1 = ((@type@ *)ip1)[1];
        const @type@ re2 = ((@type@ *)ip2)[0];
        const @type@ im2 = ((@type@ *)ip2)[1];
        /* ip1.conj() * ip2 */
        acc_re += (@sumtype@)(re1 * re2 + im1 * im2);
        acc_im += (@sumtype@)(re1 * im2 - im1 * re2);
    }
    out[0] = (@type@)acc_re;
    out[1] = (@type@)acc_im;
}
/**end repeat**/

/*
 * Conjugating dot product for Python objects:
 * op = sum over i of ip1[i].conjugate() * ip2[i], for n items with byte
 * strides is1 and is2.  NULL input slots are treated as None.
 *
 * For n == 0 the result is the integer 0, matching what the object matmul
 * loop in this file stores for an empty contraction (previously the output
 * slot was set to NULL).
 *
 * On success the result replaces the object in the output slot, whose
 * previous content is released.  On a Python error the exception is left
 * set and the output slot is not touched.
 */
static void
OBJECT_dotc(char *ip1, npy_intp is1, char *ip2, npy_intp is2, char *op, npy_intp n,
            void *NPY_UNUSED(ignore))
{
    npy_intp i;
    PyObject *result = NULL;

    for (i = 0; i < n; i++, ip1 += is1, ip2 += is2) {
        PyObject *obj1 = *(PyObject**)ip1, *obj2 = *(PyObject**)ip2;
        if (obj1 == NULL) {
            obj1 = Py_None;  /* raise useful error message at conjugate */
        }
        if (obj2 == NULL) {
            obj2 = Py_None;  /* raise at multiply (unless obj1 handles None) */
        }
        PyObject *obj1_conj = PyObject_CallMethod(obj1, "conjugate", NULL);
        if (obj1_conj == NULL) {
            goto fail;
        }
        PyObject *prod = PyNumber_Multiply(obj1_conj, obj2);
        Py_DECREF(obj1_conj);
        if (prod == NULL) {
            goto fail;
        }
        if (i == 0) {
            result = prod;
        }
        else {
            Py_SETREF(result, PyNumber_Add(result, prod));
            Py_DECREF(prod);
            if (result == NULL) {
                goto fail;
            }
        }
    }
    if (result == NULL) {
        /* only reachable when the loop body never ran: empty sum is 0 */
        result = PyLong_FromLong(0);
        if (result == NULL) {
            return;
        }
    }
    Py_XSETREF(*((PyObject **)op), result);
    return;

  fail:
    Py_XDECREF(result);
    return;
}

/**begin repeat
 *  #TYPE = FLOAT, DOUBLE, LONGDOUBLE, HALF,
 *          UBYTE, USHORT, UINT, ULONG, ULONGLONG,
 *          BYTE, SHORT, INT, LONG, LONGLONG, BOOL,
 *          CFLOAT, CDOUBLE, CLONGDOUBLE, OBJECT#
 *  #DOT = dot*15, dotc*4#
 *  #CHECK_PYERR = 0*18, 1#
 *  #CHECK_BLAS = 1*2, 0*13, 1*2, 0*2#
 */
/*
 * Inner loop of vecdot, signature (n),(n)->().
 *
 * dimensions[0] is the outer loop count and dimensions[1] the vector
 * length; steps[0..2] are the outer byte strides of the two inputs and the
 * output, steps[3..4] the element strides of the two input vectors.  The
 * pointers in args are advanced in place after each completed outer
 * iteration.
 */
NPY_NO_EXPORT void
@TYPE@_vecdot(char **args, npy_intp const *dimensions, npy_intp const *steps,
              void *NPY_UNUSED(func))
{
    const npy_intp n_outer = dimensions[0];
    const npy_intp n_inner = dimensions[1];
    const npy_intp s0 = steps[0];
    const npy_intp s1 = steps[1];
    const npy_intp s2 = steps[2];
    const npy_intp is1 = steps[3];
    const npy_intp is2 = steps[4];
    npy_intp i = 0;

    while (i < n_outer) {
        /*
         * TODO: use new API to select inner loop with get_loop and
         * improve error treatment.
         */
        @TYPE@_@DOT@(args[0], is1, args[1], is2, args[2], n_inner, NULL);
#if @CHECK_PYERR@
        /* stop at the first Python error, leaving args where they are */
        if (PyErr_Occurred()) {
            return;
        }
#endif
        i++;
        args[0] += s0;
        args[1] += s1;
        args[2] += s2;
    }
#if @CHECK_BLAS@ && NPY_BLAS_CHECK_FPE_SUPPORT
    if (!npy_blas_supports_fpe()) {
        npy_clear_floatstatus_barrier((char*)args);
    }
#endif
}
/**end repeat**/

#if defined(HAVE_CBLAS)
/*
 * Blas complex vector-matrix product via gemm (gemv cannot conjugate the vector).
 */
/**begin repeat
 *
 * #name = CFLOAT, CDOUBLE#
 * #typ = npy_cfloat, npy_cdouble#
 * #prefix = c, z#
 * #step1 = &oneF, &oneD#
 * #step0 = &zeroF, &zeroD#
 */
/*
 * Complex vector-matrix product through cblas_@prefix@gemm.
 *
 * The vector ip1 (n items, byte stride is1_n) is passed as an n x 1 matrix
 * with CblasConjTrans, the matrix ip2 is n x m, and the result op is
 * written as a 1 x m row with leading dimension m.
 *
 * Requirements, checked by asserts only (so not in release builds): sizes
 * within BLAS_MAXSIZE, an output stride os_m of exactly one item, a
 * BLAS-usable vector stride, and one contiguous axis in ip2.
 */
static void
@name@_vecmat_via_gemm(void *ip1, npy_intp is1_n,
                       void *ip2, npy_intp is2_n, npy_intp is2_m,
                       void *op, npy_intp os_m,
                       npy_intp n, npy_intp m)
{
    enum CBLAS_TRANSPOSE trans2 = CblasNoTrans;
    npy_intp leading2 = is2_n;

    assert(n <= BLAS_MAXSIZE && m <= BLAS_MAXSIZE);
    assert(os_m == sizeof(@typ@));
    assert(is_blasable2d(is1_n, sizeof(@typ@), n, 1, sizeof(@typ@)));

    if (!is_blasable2d(is2_n, is2_m, n, m, sizeof(@typ@))) {
        /* the m axis must then be the leading one */
        assert(is_blasable2d(is2_m, is2_n, m, n, sizeof(@typ@)));
        trans2 = CblasTrans;
        leading2 = is2_m;
    }

    const CBLAS_INT N = (CBLAS_INT)n;
    const CBLAS_INT M = (CBLAS_INT)m;
    const CBLAS_INT lda = (CBLAS_INT)(is1_n / sizeof(@typ@));
    const CBLAS_INT ldb = (CBLAS_INT)(leading2 / sizeof(@typ@));
    const CBLAS_INT ldc = (CBLAS_INT)m;

    CBLAS_FUNC(cblas_@prefix@gemm)(
        CblasRowMajor, CblasConjTrans, trans2, 1, M, N, @step1@, ip1, lda,
        ip2, ldb, @step0@, op, ldc);
}
/**end repeat**/
#endif

/*
 * matvec loops, using blas gemv if possible, and TYPE_dot implementations otherwise.
 * signature is (m,n),(n)->(m)
 */
/**begin repeat
 *  #TYPE = FLOAT, DOUBLE, LONGDOUBLE, HALF,
 *          CFLOAT, CDOUBLE, CLONGDOUBLE,
 *          UBYTE, USHORT, UINT, ULONG, ULONGLONG,
 *          BYTE, SHORT, INT, LONG, LONGLONG,
 *          BOOL, OBJECT#
 *  #typ = npy_float,npy_double,npy_longdouble, npy_half,
 *         npy_cfloat, npy_cdouble, npy_clongdouble,
 *         npy_ubyte, npy_ushort, npy_uint, npy_ulong, npy_ulonglong,
 *         npy_byte, npy_short, npy_int, npy_long, npy_longlong,
 *         npy_bool, npy_object#
 * #USEBLAS = 1, 1, 0, 0, 1, 1, 0*13#
 * #CHECK_PYERR = 0*18, 1#
 */
/*
 * Inner loop of matvec, signature (m,n),(n)->(m).
 *
 * dimensions holds the outer loop count followed by dm and dn; steps holds
 * the three outer byte strides followed by is1_m, is1_n, is2_n and os_m.
 * The pointers in args are advanced in place once per outer iteration.
 *
 * gemv is used when BLAS is enabled for this type and the matrix, the
 * input vector and the output vector all have strides BLAS can express
 * (the output stride is passed to gemv as an increment, so it needs the
 * same check as the inputs).  Otherwise each output element is computed by
 * a separate dot call.
 */
NPY_NO_EXPORT void
@TYPE@_matvec(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func))
{
    npy_intp n_outer = dimensions[0];
    npy_intp s0=steps[0], s1=steps[1], s2=steps[2];
    npy_intp dm = dimensions[1], dn = dimensions[2];
    npy_intp is1_m=steps[3], is1_n=steps[4], is2_n=steps[5], os_m=steps[6];
#if @USEBLAS@ && defined(HAVE_CBLAS)
    npy_bool too_big_for_blas = (dm > BLAS_MAXSIZE || dn > BLAS_MAXSIZE);
    npy_bool i1_c_blasable = is_blasable2d(is1_m, is1_n, dm, dn, sizeof(@typ@));
    npy_bool i1_f_blasable = is_blasable2d(is1_n, is1_m, dn, dm, sizeof(@typ@));
    npy_bool i2_blasable = is_blasable2d(is2_n, sizeof(@typ@), dn, 1, sizeof(@typ@));
    /* output stride must be a positive whole number of items for gemv */
    npy_bool o_blasable = is_blasable2d(os_m, sizeof(@typ@), dm, 1, sizeof(@typ@));
    npy_bool blasable = ((i1_c_blasable || i1_f_blasable) && i2_blasable
                         && o_blasable
                         && !too_big_for_blas && dn > 1 && dm > 1);
#endif
    for (npy_intp i = 0; i < n_outer; i++,
             args[0] += s0, args[1] += s1, args[2] += s2) {
        char *ip1=args[0], *ip2=args[1], *op=args[2];
#if @USEBLAS@ && defined(HAVE_CBLAS)
        if (blasable) {
            @TYPE@_gemv(ip1, is1_m, is1_n, ip2, is2_n, op, os_m, dm, dn);
            continue;
        }
#endif
        /*
         * Dot the different matrix rows with the vector to get output elements.
         * (no conjugation for complex, unlike vecdot and vecmat)
         */
        for (npy_intp j = 0; j < dm; j++, ip1 += is1_m, op += os_m) {
            @TYPE@_dot(ip1, is1_n, ip2, is2_n, op, dn, NULL);
#if @CHECK_PYERR@
            if (PyErr_Occurred()) {
                return;
            }
#endif
        }
    }
#if @USEBLAS@ && NPY_BLAS_CHECK_FPE_SUPPORT
    if (!npy_blas_supports_fpe()) {
        npy_clear_floatstatus_barrier((char*)args);
    }
#endif
}
/**end repeat**/

/*
 * vecmat loops, using blas gemv for float and gemm for complex if possible,
 * and TYPE_dot[c] implementations otherwise.
 * Note that we cannot use gemv for complex, since we need to conjugate the vector.
 * signature is (n),(n,m)->(m)
 */
/**begin repeat
 *  #TYPE = FLOAT, DOUBLE, LONGDOUBLE, HALF,
 *          CFLOAT, CDOUBLE, CLONGDOUBLE,
 *          UBYTE, USHORT, UINT, ULONG, ULONGLONG,
 *          BYTE, SHORT, INT, LONG, LONGLONG,
 *          BOOL, OBJECT#
 *  #typ = npy_float,npy_double,npy_longdouble, npy_half,
 *         npy_cfloat, npy_cdouble, npy_clongdouble,
 *         npy_ubyte, npy_ushort, npy_uint, npy_ulong, npy_ulonglong,
 *         npy_byte, npy_short, npy_int, npy_long, npy_longlong,
 *         npy_bool, npy_object#
 * #USEBLAS = 1, 1, 0, 0, 1, 1, 0*13#
 * #COMPLEX = 0*4, 1*3, 0*11, 1#
 * #DOT = dot*4, dotc*3, dot*11, dotc#
 * #CHECK_PYERR = 0*18, 1#
 */
/*
 * Inner loop of vecmat, signature (n),(n,m)->(m).
 *
 * dimensions holds the outer loop count followed by dn and dm; steps holds
 * the three outer byte strides followed by is1_n, is2_n, is2_m and os_m.
 * The pointers in args are advanced in place once per outer iteration.
 *
 * BLAS is used when it is enabled for this type and the vector, the matrix
 * and the output all have a layout the BLAS call can express:
 *   - complex types go through gemm (so the vector can be conjugated),
 *     which writes the result contiguously, so the output stride must be
 *     exactly one item;
 *   - real types go through gemv, which takes the output stride as an
 *     increment, so it must be a positive whole number of items.
 * Otherwise each output element is computed by a separate dot call.
 */
NPY_NO_EXPORT void
@TYPE@_vecmat(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func))
{
    npy_intp n_outer = dimensions[0];
    npy_intp s0=steps[0], s1=steps[1], s2=steps[2];
    npy_intp dn = dimensions[1], dm = dimensions[2];
    npy_intp is1_n=steps[3], is2_n=steps[4], is2_m=steps[5], os_m=steps[6];
#if @USEBLAS@ && defined(HAVE_CBLAS)
    npy_bool too_big_for_blas = (dm > BLAS_MAXSIZE || dn > BLAS_MAXSIZE);
    npy_bool i1_blasable = is_blasable2d(is1_n, sizeof(@typ@), dn, 1, sizeof(@typ@));
    npy_bool i2_c_blasable = is_blasable2d(is2_n, is2_m, dn, dm, sizeof(@typ@));
    npy_bool i2_f_blasable = is_blasable2d(is2_m, is2_n, dm, dn, sizeof(@typ@));
#if @COMPLEX@
    /* the gemm helper writes a contiguous row, ignoring os_m */
    npy_bool o_blasable = (os_m == (npy_intp)sizeof(@typ@));
#else
    /* gemv takes the output stride as an increment */
    npy_bool o_blasable = is_blasable2d(os_m, sizeof(@typ@), dm, 1, sizeof(@typ@));
#endif
    npy_bool blasable = (i1_blasable && (i2_c_blasable || i2_f_blasable)
                         && o_blasable
                         && !too_big_for_blas && dn > 1 && dm > 1);
#endif
    for (npy_intp i = 0; i < n_outer; i++,
             args[0] += s0, args[1] += s1, args[2] += s2) {
        char *ip1=args[0], *ip2=args[1], *op=args[2];
#if @USEBLAS@ && defined(HAVE_CBLAS)
        if (blasable) {
#if @COMPLEX@
            /* For complex, use gemm so we can conjugate the vector */
            @TYPE@_vecmat_via_gemm(ip1, is1_n, ip2, is2_n, is2_m, op, os_m, dn, dm);
#else
            /* For float, use gemv (hence flipped order) */
            @TYPE@_gemv(ip2, is2_m, is2_n, ip1, is1_n, op, os_m, dm, dn);
#endif
            continue;
        }
#endif
        /* Dot the vector with different matrix columns to get output elements. */
        for (npy_intp j = 0; j < dm; j++, ip2 += is2_m, op += os_m) {
            @TYPE@_@DOT@(ip1, is1_n, ip2, is2_n, op, dn, NULL);
#if @CHECK_PYERR@
            if (PyErr_Occurred()) {
                return;
            }
#endif
        }
    }
#if @USEBLAS@ && NPY_BLAS_CHECK_FPE_SUPPORT
    if (!npy_blas_supports_fpe()) {
        npy_clear_floatstatus_barrier((char*)args);
    }
#endif
}
/**end repeat**/
